# Copyright 2019 Mike D'Arcy. All rights reserved.
# This file is part of Drexo.

import collections
import copy
import csv
import io
import json
import os
import random
import re
import shutil
import subprocess
import sys
import time

class RunnerEngine:
	"""Dispatches subprocess jobs onto free NVIDIA GPUs, at most one
	managed process per GPU.

	A GPU is considered unavailable if this engine already has a live
	process on it, if its id appears in the on-disk block file, or if
	nvidia-smi reports any compute process running on it (e.g. another
	user's job).
	"""

	def __init__(self, gpu_ids=None, gpu_block_filename='/tmp/drexo_block_gpus'):
		# gpu_ids: iterable of GPU id strings to schedule onto; when
		# falsy, defaults to every GPU nvidia-smi reports (restricted to
		# CUDA_VISIBLE_DEVICES if that env var is set).
		# gpu_block_filename: newline-separated file of GPU ids that
		# should be treated as unavailable.
		if not gpu_ids:
			gpu_ids = self._get_default_gpu_ids()
		# gpu_id -> runobj dict for the running job, or None when idle.
		self.gpu_processes = {gpu_id: None for gpu_id in gpu_ids}

		self.job_queue = collections.deque([])
		self.blocked_gpus = set()
		self.gpu_block_filename = gpu_block_filename

		# Map GPU uuid -> index so compute-app info from nvidia-smi
		# (which is keyed by uuid) can be matched back to GPU indices.
		raw_uuidinfo = subprocess.check_output(['nvidia-smi', '--query-gpu=index,uuid', '--format=csv'])
		reader = csv.DictReader(io.StringIO(raw_uuidinfo.decode('ascii', errors='ignore')), skipinitialspace=True)
		self._gpu_uuid_map = dict()
		for row in reader:
			self._gpu_uuid_map[row['uuid']] = row['index']

	def _get_default_gpu_ids(self):
		"""Return the default GPU id strings: the CUDA_VISIBLE_DEVICES
		entries if that env var is set, else '0'..'num_gpus-1'."""
		gpu_ids = [str(i) for i in range(self.get_num_gpus())]
		if 'CUDA_VISIBLE_DEVICES' in os.environ:
			gpu_ids = [x for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
		return gpu_ids

	def get_num_gpus(self):
		"""Count GPUs as the number of non-empty lines in `nvidia-smi -L`."""
		return len([x for x in subprocess.check_output(['nvidia-smi', '-L']).decode('ascii', errors='ignore').split('\n') if len(x) > 0])

	def wait_for_free_gpu(self, poll_interval=5, timeout=-1):
		"""Poll every poll_interval seconds until some managed GPU is
		both idle and unblocked, and return its id.

		A negative timeout means wait forever; otherwise None is
		returned once `timeout` seconds elapse without a free GPU.
		"""
		start_time = time.perf_counter()
		while timeout < 0 or (time.perf_counter() - start_time) < timeout:
			self._refresh_processes()
			self._refresh_blocked_gpus()

			for gpu_id in self.gpu_processes:
				if self.gpu_processes[gpu_id] is None and gpu_id not in self.blocked_gpus:
					return gpu_id

			time.sleep(poll_interval)

		return None

	def _on_process_finished(self, runobj):
		"""Clean up after a finished job: close its log files (if any
		were opened; the open calls are currently commented out in
		_run_job), mark its GPU idle, and invoke the job's callback
		with the process exit code."""
		if 'logfile_stdout' in runobj:
			runobj['logfile_stdout'].close()
		if 'logfile_stderr' in runobj:
			runobj['logfile_stderr'].close()

		self.gpu_processes[runobj['gpu_id']] = None
		runobj['runconfig'].callback(runobj['popen_obj'].returncode)

	def wait_on_process(self, runobj):
		"""Block until runobj's process exits, then run its cleanup."""
		runobj['popen_obj'].wait()
		self._on_process_finished(runobj)

	def _refresh_processes(self):
		"""Reap any managed processes that have exited since the last
		check (non-blocking poll per GPU)."""
		for gpu_id in self.gpu_processes:
			if not self.gpu_processes[gpu_id]:
				continue
			exit_status = self.gpu_processes[gpu_id]['popen_obj'].poll()
			if exit_status is not None:
				self._on_process_finished(self.gpu_processes[gpu_id])

	def _refresh_blocked_gpus(self):
		"""Rebuild self.blocked_gpus from the block file plus every GPU
		on which nvidia-smi reports a running compute process."""
		self.blocked_gpus = set()
		if os.path.exists(self.gpu_block_filename):
			with open(self.gpu_block_filename, 'r') as f:
				for line in f.readlines():
					# Only honor ids this engine actually manages.
					if line.strip() in self.gpu_processes:
						self.blocked_gpus.add(line.strip())

		# Check nvidia-smi for other processes
		raw_uuidinfo = subprocess.check_output(['nvidia-smi', '--query-compute-apps=pid,gpu_uuid', '--format=csv'])
		reader = csv.DictReader(io.StringIO(raw_uuidinfo.decode('ascii', errors='ignore')), skipinitialspace=True)
		for row in reader:
			if row['gpu_uuid'] not in self._gpu_uuid_map:
				continue
			self.blocked_gpus.add(self._gpu_uuid_map[row['gpu_uuid']])

	def add_job(self, new_job, front=False):
		"""Queue a RunConfig; front=True puts it at the head of the queue."""
		if front:
			self.job_queue.appendleft(new_job)
		else:
			self.job_queue.append(new_job)

	def is_done(self):
		"""True when no jobs are queued and every managed GPU is idle."""
		return len(self.job_queue) == 0 and all(self.gpu_processes[x] is None for x in self.gpu_processes)

	def _run_job(self, job):
		"""Start a queued RunConfig on the next free GPU.

		A job with cmd_args of None is a no-op: its callback (if any) is
		invoked immediately with no arguments and nothing is launched.
		Otherwise this blocks until a GPU frees up, launches the command
		with CUDA_VISIBLE_DEVICES pinned to that GPU, and writes
		execution metadata into the job's output directory.
		"""
		if job.cmd_args is None:
			if job.callback:
				job.callback()
			return

		# Default timeout is infinite, so gpu_id cannot be None here.
		gpu_id = self.wait_for_free_gpu()
		env = os.environ.copy()
		env['CUDA_VISIBLE_DEVICES'] = gpu_id

		print('CUDA_VISIBLE_DEVICES={}'.format(gpu_id), job.cmd_args)

		runobj = dict()

		#runobj['logfile_stdout'] = open(os.path.join(job.local_output_dir, 'stdout.log'), 'w')
		#runobj['logfile_stderr'] = open(os.path.join(job.local_output_dir, 'stderr.log'), 'w')

		runobj['start_time'] = time.time()
		runobj['gpu_id'] = gpu_id
		runobj['runconfig'] = job
		runobj['popen_obj'] = subprocess.Popen(job.cmd_args, env=env) #, stdout=runobj['logfile_stdout'], stderr=runobj['logfile_stderr'])
		self.gpu_processes[gpu_id] = runobj

		with open(os.path.join(job.local_output_dir, 'execution_info.json'), 'w') as f:
			json.dump({'start_time': runobj['start_time'], 'gpu_id': gpu_id, 'environment': env}, f)

	def run(self):
		"""Main loop: launch queued jobs as GPUs free up; once the queue
		drains, keep reaping running processes until everything is done."""
		while not self.is_done():
			if len(self.job_queue) != 0:
				job = self.job_queue.popleft()
				self._run_job(job)
			else:
				self._refresh_processes()
				time.sleep(5)

class RunConfig:
	"""Lightweight job descriptor consumed by RunnerEngine.

	cmd_args is the argv list handed to subprocess.Popen, or None for
	a no-op job that only fires its callback. local_output_dir is where
	execution metadata is written. callback is invoked when the job
	finishes (with the exit status for real jobs; with no arguments for
	no-op jobs).
	"""

	def __init__(self, cmd_args, local_output_dir, callback):
		self.callback = callback
		self.local_output_dir = local_output_dir
		self.cmd_args = cmd_args

class RunManager:
	"""Queues parameterized jobs on a RunnerEngine, giving each run its
	own runid_NNNNN output directory and skipping runs whose parameters
	match a previously successful (cached) run."""

	def __init__(self, base_output_dir, params_to_args_func, params_compare_func=None, job_finished_func=None, runner=None):
		"""Constructor.

		base_output_dir is the path to an output directory that will
		hold all of the subdirectories with job execution information.

		params_to_args_func(params, output_dir, seed) should be
		a function that takes `params` (a dict, in whatever format the
		caller expects to be passed to the queuing function of this
		class), the path to an output_dir (a string) and a random seed
		(int) and returns a list representing a complete command line
		(i.e., of the same format accepted by the `subprocess` module's
		`Popen` function).

		params_compare_func is an optional function that takes two sets
		of params and returns True if they are effectively equivalent
		for purposes of caching (so params found "equal" to any
		previously-run job's params will be rejected from the queuing
		function). If this argument is None, caching will be disabled.

		job_finished_func is an optional function that takes
		a full_runconfig and returns nothing. This function is run when
		a job finishes (but not when a job is skipped due to caching)
		and can be used to do things like cleaning up large output
		files or compiling results from output data.

		runner is a RunnerEngine instance used to actually run the jobs
		passed to the queuing function.
		"""
		self.base_output_dir = base_output_dir
		self.params_to_args_func = params_to_args_func
		self.params_compare_func = params_compare_func
		if self.params_compare_func is None:
			# Caching disabled: no two parameter sets ever compare equal.
			self.params_compare_func = lambda params1, params2: False

		self.job_finished_func = job_finished_func

		self.runner = runner
		if self.runner is None:
			self.runner = RunnerEngine()

		self._errors_encountered = []

		self.next_runid = 0

		# Marker file written into a run's dir when its job completes;
		# records whether it finished and with what exit status.
		self._finished_filename = 'run_finished.json'

		# Paths to runconfig.json files of previous successful runs.
		# Each entry can satisfy at most one cache lookup.
		self._cached_runconfigs = self._gen_initial_cache()
		# Directories owned by cached runs; new runids must not reuse them.
		self._reserved_dirs = [os.path.normpath(os.path.realpath(os.path.dirname(x))) for x in self._cached_runconfigs]

	def new_runid(self):
		"""Return the next unused run id (monotonically increasing ints
		starting at 0)."""
		runid = self.next_runid
		self.next_runid += 1
		return runid

	def _record_error(self, errstr, errtype='unspecified', errobj=None, extra_info=None):
		"""Append an error record (retrievable via get_errors).

		errstr: human-readable description. errtype: short category tag.
		errobj: optional thrown object to attach. extra_info: optional
		extra context (e.g. {'runid': ...}).
		"""
		blob = {
			'description': errstr,
			'type': errtype,
		}
		if errobj is not None:
			blob['thrown_obj'] = errobj
		if extra_info is not None:
			blob['extra_info'] = extra_info
		self._errors_encountered.append(blob)

	def get_errors(self):
		"""Return the list of error records accumulated so far."""
		return self._errors_encountered

	def add_to_run_queue(self, orig_params, callback, front=False):
		"""The queuing function.

		Queues a job built from orig_params. If a cached run with
		equivalent params exists, a no-op job is queued that replays the
		cached run's output dir and seed to `callback`. Otherwise a
		fresh runid and output directory are allocated, the command line
		is built via params_to_args_func, and the job is queued on the
		runner. callback(local_output_dir, seed) is invoked when the job
		(or its cached stand-in) completes.
		"""
		# Deep-copy so later caller-side mutation cannot affect the run.
		params = copy.deepcopy(orig_params)

		full_runconfig = {
			'params': params,
			'seed': str(random.randint(0, 10000000)),
		}

		cached_config = self._lookup_cached_equivalent(params)
		if cached_config is not None:
			print('Found cached config in runid={}: Adding dummy to queue instead of the real job.'.format(cached_config['runid']))
			def sub_callback():
				# cmd_args=None makes the runner treat this as a no-op
				# job that just fires this callback.
				callback(cached_config['local_output_dir'], cached_config['seed'])
			self.runner.add_job(RunConfig(None, None, sub_callback), front=front)
			return

		# Allocate a runid whose directory is not reserved by a cached run.
		runid = -1
		local_output_dir = None
		while not local_output_dir:
			runid = self.new_runid()
			local_output_dir = os.path.join(self.base_output_dir, 'runid_{:05d}'.format(runid))
			if os.path.normpath(os.path.realpath(local_output_dir)) in self._reserved_dirs:
				local_output_dir = None

		print('No cached params for runid={}. Queuing for execution.'.format(runid))

		full_runconfig['runid'] = runid
		full_runconfig['local_output_dir'] = local_output_dir

		job_output_dir = os.path.join(local_output_dir, 'output')
		full_runconfig['job_output_dir'] = job_output_dir

		# Fully reset the output dir if it exists, so previous runs
		# cannot interfere with new runs
		if os.path.exists(local_output_dir):
			shutil.rmtree(local_output_dir)
		os.makedirs(local_output_dir)
		os.makedirs(job_output_dir)

		full_runconfig['final_args'] = self.params_to_args_func(params, full_runconfig['job_output_dir'], full_runconfig['seed'])

		with open(os.path.join(local_output_dir, 'runconfig.json'), 'w') as f:
			json.dump(full_runconfig, f)

		def sub_callback(exit_status):
			# Runs when the job's process exits.
			if self.job_finished_func is not None:
				self.job_finished_func(copy.copy(full_runconfig))

			if exit_status != 0:
				self._record_error('Nonzero exit status', errtype='process_exit_status', extra_info={'runid': full_runconfig['runid']})

			# Mark the run finished so future RunManagers can cache it
			# (_gen_initial_cache only reuses successful runs).
			finish_obj = {
				'finished': True,
				'exit_status': exit_status,
			}
			with open(os.path.join(full_runconfig['local_output_dir'], self._finished_filename), 'w') as f:
				json.dump(finish_obj, f)

			callback(full_runconfig['local_output_dir'], full_runconfig['seed'])

		self.runner.add_job(RunConfig(full_runconfig['final_args'], full_runconfig['local_output_dir'], sub_callback), front=front)

	def _gen_initial_cache(self):
		"""Scan base_output_dir for completed, successful runs and
		return the list of their runconfig.json paths."""
		if not os.path.exists(self.base_output_dir):
			return []

		all_cached_runconfigs = []

		subdirs = sorted(next(os.walk(self.base_output_dir))[1])
		for subdir in subdirs:
			finished_json_path = os.path.join(self.base_output_dir, subdir, self._finished_filename)
			runconfig_path = os.path.join(self.base_output_dir, subdir, 'runconfig.json')

			# Only consider runid_* dirs that have both metadata files.
			if not (re.match(r'^runid_[0-9]+$', subdir) and os.path.exists(finished_json_path) and os.path.exists(runconfig_path)):
				continue

			with open(finished_json_path, 'r') as f:
				tmp = json.load(f)
				# Only runs that finished with exit status 0 are cacheable.
				if 'finished' not in tmp or not tmp['finished']:
					continue
				if 'exit_status' not in tmp or tmp['exit_status'] != 0:
					continue

			all_cached_runconfigs.append(runconfig_path)
		return all_cached_runconfigs

	def _lookup_cached_equivalent(self, params):
		"""Return the full runconfig of a cached run whose params are
		equivalent to `params` (per params_compare_func), or None.

		A matching entry is consumed (removed from the cache) so it can
		be returned at most once. Entries whose runconfig file has been
		deleted since startup are pruned and skipped; previously a
		missing file aborted the whole scan early, silently ignoring
		every remaining valid cache entry.
		"""
		# Iterate over a snapshot since we mutate the list while scanning.
		for runconfig_path in list(self._cached_runconfigs):
			if not os.path.exists(runconfig_path):
				# Stale entry; drop it and keep scanning.
				self._cached_runconfigs.remove(runconfig_path)
				continue
			with open(runconfig_path, 'r') as f:
				runconfig = json.load(f)

			# Check if runconfig matches params
			if self.params_compare_func(params, runconfig['params']):
				self._cached_runconfigs.remove(runconfig_path)
				return runconfig
		return None

